
"""Experiment driver: run K-subspaces (L2) and K-z-subspaces subset experiments and pickle the results."""

import os
import pickle

import numpy as np
import torch

import util
from result_collection import Result_collection, Run_results
from subspace_clustering import K_Subspaces_algo
from subspace_z_alg import Kz_Subspaces_algo
from util import Dataset


def run_subspacel2_experiment():
    """Run the K-subspaces (L2) subset experiments and pickle one result file per (dataset, j).

    For every loaded dataset and every ``j`` from ``get_params()``, a
    ``Result_collection`` is filled by ``run_experiment`` for all ``k`` values.
    It is then written to ``results/results_<file_name>.pkl``. Values of ``j``
    that are not smaller than the dataset's column count are skipped.
    The local ``results/`` directory is created if it does not exist.
    """
    results_dir = "results"
    # Create the output directory up front; otherwise a missing folder only
    # raises at save time, after all experiments for a j have already run.
    os.makedirs(results_dir, exist_ok=True)

    datasets = util.load_normalized_datasets([Dataset.TEST_DATA])
    k_values, n_runs, max_iter, j_values = get_params()
    for dataset, dataset_name in datasets:
        for j in j_values:
            if j >= dataset.shape[1]:
                # Skip only this j (not the rest of j_values), so the check does
                # not rely on j_values being sorted ascending.
                print("skipping")
                continue
            # Smallest exponent e with 2**e > max(k) * j, so even the smallest
            # subset has more rows than max(k) * j; sizes go up to 2**12.
            min_subset_size = int(np.log2(np.max(k_values) * j)) + 1
            max_subset_size = 13
            subset_sizes = [2 ** i for i in range(min_subset_size, max_subset_size)]
            res_collection = Result_collection(
                k_values, subset_sizes, n_runs, dataset_name, "K_Subspaces_l2", j=j
            )
            for k in k_values:
                print(f"dataset: {dataset_name}, j: {j}, k: {k}")
                alg = K_Subspaces_algo(k, j, max_iter=max_iter)
                res_collection = run_experiment(
                    dataset, subset_sizes, n_runs, alg, k, res_collection
                )
            file_name = res_collection.get_file_name()
            with open(os.path.join(results_dir, f"results_{file_name}.pkl"), "wb") as f:
                pickle.dump(res_collection, f)
            print(f"Finished {file_name}")

def get_params():
    """Return the shared experiment grid as ``(k_values, n_runs, max_iter, j_values)``.

    - ``k_values``: the ``k`` values each algorithm is constructed with.
    - ``n_runs``: how many random subsets ``run_experiment`` draws per subset size.
    - ``max_iter``: the iteration cap passed to the algorithms.
    - ``j_values``: the ``j`` values; callers skip any ``j`` that is not smaller
      than the dataset's column count.

    Fresh lists are built on every call, so callers may mutate them safely.
    """
    cluster_counts = [10, 20, 30, 50]
    subspace_dims = [1, 2, 5]
    runs_per_size = 5
    iteration_cap = 40
    return cluster_counts, runs_per_size, iteration_cap, subspace_dims
                
def run_experiment(dataset, subset_sizes, n_runs, algorithm, k, res_collection, rng=None):
    """Evaluate ``algorithm`` on random row subsets of ``dataset`` and record the results.

    The full-data solution is computed once. Then, for each size in
    ``subset_sizes``, ``n_runs`` row subsets are sampled without replacement and
    passed to ``algorithm.get_subset_solution_original``. The outcomes are
    collected in one ``Run_results`` per size, which is added to
    ``res_collection``.

    Args:
        dataset: 2-D array-like indexed as ``dataset[rows, :]``; rows are sampled.
        subset_sizes: Iterable of subset sizes (row counts) to evaluate.
        n_runs: Number of random subsets drawn per subset size.
        algorithm: Object providing ``get_original_solution(dataset)`` and
            ``get_subset_solution_original(dataset, subset)``. The latter returns
            ``(subset_cost, subset_solution_original, centers_subset)``.
            Presumably ``subset_solution_original`` is the subset solution's cost
            on the full data; confirm against the algorithm classes.
        k: The ``k`` value recorded in each ``Run_results``.
        res_collection: Collection receiving one ``add_result`` call per subset size.
        rng: Optional random source exposing ``choice(a, size, replace=...)``,
            e.g. ``np.random.default_rng(seed)`` for reproducible runs.
            Defaults to the global ``np.random`` state, as before.

    Returns:
        ``res_collection`` (the same object, after the ``add_result`` calls).

    Raises:
        ValueError: if any subset size exceeds the number of rows in ``dataset``.
            This is checked before the full-data solve is computed.
    """
    if rng is None:
        rng = np.random
    subset_sizes = list(subset_sizes)
    n_rows = dataset.shape[0]
    too_large = [size for size in subset_sizes if size > n_rows]
    if too_large:
        # Sampling without replacement cannot exceed the population; fail before
        # the (potentially expensive) full-data solve instead of midway through.
        raise ValueError(
            f"subset sizes {too_large} exceed the number of rows in the dataset ({n_rows})"
        )

    print("getting original cost")
    centers_full, original_cost = algorithm.get_original_solution(dataset)
    for subset_size in subset_sizes:
        run_res = Run_results(k, subset_size, n_runs)
        run_res.set_centers_full(centers_full)
        run_res.set_originalcost(original_cost)
        for run in range(n_runs):
            print(f"subset_size: {subset_size}, run: {run}, k: {k}")
            row_idx = rng.choice(n_rows, subset_size, replace=False)
            subset = dataset[row_idx, :]
            subset_cost, subset_solution_original, centers_subset = (
                algorithm.get_subset_solution_original(dataset, subset)
            )
            run_res.add_subset_cost(subset_cost)
            run_res.add_subset_solution_original(subset_solution_original)
            run_res.add_centers_subset(centers_subset)
        res_collection.add_result(run_res)
    return res_collection


def run_subspace_z():
    """Run the K-z-subspaces subset experiments and pickle one result file per (dataset, z, j).

    For every loaded dataset, every ``z`` in ``z_values`` and every ``j`` from
    ``get_params()``, a ``Result_collection`` is filled by ``run_experiment``
    for all ``k`` values. It is then written to ``results_<file_name>.pkl``.
    Values of ``j`` that are not smaller than the dataset's column count are
    skipped.

    Output location:

    - When CUDA is available, results go to a Google Drive folder. Presumably
      this is a Colab session with Drive mounted; confirm. That folder must
      already exist, and this is checked before any computation.
    - Otherwise results go to the local ``results/`` directory, which is
      created if missing.

    Raises:
        FileNotFoundError: if CUDA is available but the Drive results folder
            does not exist.
    """
    if torch.cuda.is_available():
        results_dir = "/content/drive/MyDrive/subspace_implementation/results"
        # Deliberately not created: if Drive is not mounted we want to fail now,
        # rather than after all experiments ran or by writing to a local path.
        if not os.path.isdir(results_dir):
            raise FileNotFoundError(f"results directory not found: {results_dir}")
    else:
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)

    datasets = util.load_normalized_datasets([Dataset.TEST_DATA])
    k_values, n_runs, max_iter, j_values = get_params()
    z_values = [1, 3, 4]
    max_subset_size = 13
    for dataset, dataset_name in datasets:
        for z in z_values:
            for j in j_values:
                if j >= dataset.shape[1]:
                    # Skip only this j (not the rest of j_values), so the check
                    # does not rely on j_values being sorted ascending.
                    print("skipping")
                    continue
                # Smallest exponent e with 2**e > max(k) * j; sizes go up to 2**12.
                min_subset = int(np.log2(np.max(k_values) * j)) + 1
                subset_sizes = [2 ** i for i in range(min_subset, max_subset_size)]
                res_collection = Result_collection(
                    k_values, subset_sizes, n_runs, dataset_name,
                    util.Algorithm.K_Z_SUBSPACE.value, j=j, z=z,
                )
                for k in k_values:
                    print(f"dataset: {dataset_name}, z: {z}, j: {j}, k: {k}")
                    alg = Kz_Subspaces_algo(k, z, j, max_sgd=100, max_iter=max_iter)
                    res_collection = run_experiment(
                        dataset, subset_sizes, n_runs, alg, k, res_collection
                    )
                file_name = res_collection.get_file_name()
                with open(os.path.join(results_dir, f"results_{file_name}.pkl"), "wb") as f:
                    pickle.dump(res_collection, f)
                print(f"Finished {file_name}")

# Script entry point: run the K-z-subspaces experiments, then the K-subspaces (L2) ones.
if __name__ == "__main__":
    run_subspace_z()
    run_subspacel2_experiment()